{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googlecolab/colabtools/blob/master/notebooks/colab-github-demo.ipynb)\n",
    "\n",
    "Install the wormpose package if needed. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --upgrade wormpose"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If using Google Colab, please **restart the runtime** after installing the package: menu \"Runtime\" > \"Restart runtime\".\n",
    "\n",
    "You can also select a GPU node for faster training: menu \"Runtime\" > \"Change Runtime Type\" > select \"GPU\" in the menu \"Hardware Accelerator\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first download some utility functions to display images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://raw.githubusercontent.com/iteal/wormpose/main/examples/ipython_utils.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download sample data\n",
    "\n",
    "The sample data consists of a set of images and an h5 file containing the features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_data_root = 'wormpose_data'\n",
    "import os, shutil\n",
    "if os.path.exists(sample_data_root):\n",
    "    shutil.rmtree(sample_data_root)\n",
    "os.mkdir(sample_data_root)\n",
    "!git clone https://github.com/iteal/wormpose_data.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set inputs\n",
    "\n",
    "We load the sample dataset using the default \"sample_data\" loader, and set the path of the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from wormpose.config import default_paths\n",
    "from wormpose.dataset.loader import load_dataset\n",
    "\n",
    "# Different datasets need different loaders: we use \"sample_data\" for the tutorial data.\n",
    "# Replace it with \"tierpsy\" for Tierpsy tracker data, or with the name of your custom dataset loader\n",
    "dataset_loader = \"sample_data\"\n",
    "\n",
    "# Set the path to the dataset,\n",
    "# for Tierpsy tracker data this will be the root path of a folder containing one subfolder per video\n",
    "dataset_path = \"wormpose_data/datasets/sample_data\"\n",
    "\n",
    "print(f\"Using the default dataset loader: \\'{dataset_loader}\\', to load the sample dataset images and labels,\\n from the folder \\'{dataset_path}\\'.\\n\")\n",
    "dataset_root_name = os.path.basename(os.path.normpath(dataset_path))\n",
    "project_dir = os.path.join(default_paths.WORK_DIR, dataset_root_name)\n",
    "\n",
    "# Set if the worm is lighter than the background in the image\n",
    "# in the sample data, the worm is darker so we set this variable to False\n",
    "worm_is_lighter = False\n",
    "\n",
    "# This function loads the dataset\n",
    "# optional fields: there is an optional resize parameter to resize the images\n",
    "# also you can select specific videos from the dataset instead of loading them all\n",
    "dataset = load_dataset(dataset_loader, dataset_path, worm_is_lighter=worm_is_lighter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize the raw dataset images\n",
    "\n",
    "First, we simply display the first 100 frames of the first video of the dataset. These are the raw dataset images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ipython_utils import ImagesViewer\n",
    "\n",
    "MAX_FRAMES = 100\n",
    "img_viewer = ImagesViewer()\n",
    "\n",
    "video_name = dataset.video_names[0]\n",
    "with dataset.frames_dataset.open(video_name) as frames:\n",
    "    for frame in frames[:MAX_FRAMES]:\n",
    "        img_viewer.add_image(frame)\n",
    "        \n",
    "img_viewer.view_as_slider()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize the synthetic images\n",
    "\n",
    "The postures model generates worm postures, each of which can be drawn as a synthetic image representing that posture.\n",
    "You can run the cell below to visualize a small sample of such generated images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import tempfile\n",
    "from wormpose.demo.synthetic_simple_visualizer import SyntheticSimpleVisualizer\n",
    "from ipython_utils import ImagesViewer, display_as_slider\n",
    "\n",
    "synth_viz = SyntheticSimpleVisualizer(dataset_loader,\n",
    "                                      dataset_path, \n",
    "                                      worm_is_lighter=worm_is_lighter).generate()\n",
    "img_viewer, img_viewer_plot = ImagesViewer(), ImagesViewer()\n",
    "num_images = 50\n",
    "\n",
    "print(f\"Viewing {num_images} synthetic images.\")\n",
    "tempdir = tempfile.gettempdir()\n",
    "for i in range(num_images):\n",
    "\n",
    "    synth_image, theta = next(synth_viz)\n",
    "\n",
    "    plt.plot(theta)\n",
    "    plt.ylabel(\"theta (rad)\")\n",
    "    plt.xlabel(\"body segment\")\n",
    "    plot_path = os.path.join(tempdir, f\"theta_{i}.png\")\n",
    "    plt.savefig(plot_path)\n",
    "    plt.clf()\n",
    "    img_viewer_plot.add_image_filename(plot_path)\n",
    "\n",
    "    img_viewer.add_image(synth_image)\n",
    "\n",
    "display_as_slider(img_viewer, img_viewer_plot)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize the frame preprocessing\n",
    "\n",
    "You can run the cell below to visualize a sample of the real images after they have been processed: they are cropped (or extended) to the same size, which corresponds to the average worm length, and the pixels of the background and of non-worm objects are set to a uniform color. This makes them visually similar to the synthetic images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from wormpose.demo.real_simple_visualizer import RealSimpleVisualizer\n",
    "from ipython_utils import ImagesViewer, display_as_slider\n",
    "\n",
    "viz = RealSimpleVisualizer(dataset_loader, \n",
    "                           dataset_path, \n",
    "                           worm_is_lighter=worm_is_lighter).generate()\n",
    "orig_img_viewer, processed_img_viewer = ImagesViewer(), ImagesViewer()\n",
    "\n",
    "max_viz = 100\n",
    "print(f\"Displaying the first {max_viz} frames : original and processed.\")\n",
    "\n",
    "for _ in range(max_viz):\n",
    "    orig_image, processed_image = next(viz)\n",
    "    orig_img_viewer.add_image(orig_image)\n",
    "    processed_img_viewer.add_image(processed_image)\n",
    "\n",
    "display_as_slider(orig_img_viewer, processed_img_viewer)\n",
    "\n",
    "print(f\"The processed images are all set to the size: ({processed_image.shape[0]}px, {processed_image.shape[1]}px).\")"
   ]
  },
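  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two operations described above more concrete, the cell below sketches them with plain NumPy on a toy array: setting the pixels outside a worm mask to a uniform value, then cropping or padding the image to a fixed square size. This is only an illustration, not the wormpose implementation: the helper names `set_background` and `crop_or_pad_center`, the simple intensity threshold used as a mask, and the toy values are all assumptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def set_background(image, worm_mask, background_value):\n",
    "    \"\"\"Set every pixel outside the boolean `worm_mask` to `background_value`.\"\"\"\n",
    "    return np.where(worm_mask, image, background_value).astype(image.dtype)\n",
    "\n",
    "\n",
    "def crop_or_pad_center(image, out_size, fill_value):\n",
    "    \"\"\"Center-crop or pad a 2D image to the shape (out_size, out_size).\"\"\"\n",
    "    out = np.full((out_size, out_size), fill_value, dtype=image.dtype)\n",
    "    height, width = image.shape\n",
    "    # Size of the region that is copied along each axis\n",
    "    copy_h, copy_w = min(height, out_size), min(width, out_size)\n",
    "    # Top-left corner of the copied region in the source and in the destination\n",
    "    src_y, src_x = (height - copy_h) // 2, (width - copy_w) // 2\n",
    "    dst_y, dst_x = (out_size - copy_h) // 2, (out_size - copy_w) // 2\n",
    "    out[dst_y:dst_y + copy_h, dst_x:dst_x + copy_w] = image[src_y:src_y + copy_h, src_x:src_x + copy_w]\n",
    "    return out\n",
    "\n",
    "\n",
    "# Toy 4x6 grayscale frame: a dark \"worm\" (value 50) on a lighter, uneven background\n",
    "toy_frame = np.array([[200, 210, 190, 205, 198, 202],\n",
    "                      [201,  50,  50,  50, 207, 199],\n",
    "                      [203, 204,  50, 206, 195, 200],\n",
    "                      [199, 202, 208, 197, 201, 204]], dtype=np.uint8)\n",
    "toy_mask = toy_frame < 100  # assumed segmentation: dark pixels belong to the worm\n",
    "cleaned = set_background(toy_frame, toy_mask, background_value=200)\n",
    "toy_processed = crop_or_pad_center(cleaned, out_size=5, fill_value=200)\n",
    "print(toy_processed.shape)"
   ]
  },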
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calibration\n",
    "\n",
    "We use a pixel comparison function to compare a real labelled image to its reconstruction as a synthetic image. It assigns a score between 0 (worst) and 1 (perfect reconstruction).\n",
    "This function will be used to evaluate the predictions by comparing them to the original image, and to filter out bad results with a threshold.\n",
    "\n",
    "To decide which threshold to use, we can evaluate this image score function on real labelled images from the dataset. The resulting scores should be close to 1 if the features are correct. We can visualize the real image and the synthetic image that were used for the scoring. Here, for example, we display 5 examples that span the range of results from worst to best."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import h5py\n",
    "import matplotlib.pyplot as plt\n",
    "from wormpose.commands import calibrate\n",
    "from ipython_utils import ImagesViewer\n",
    "\n",
    "video_name, result_file = next(calibrate(dataset_loader, \n",
    "                                         dataset_path, \n",
    "                                         worm_is_lighter=worm_is_lighter,\n",
    "                                         save_images=True))\n",
    "\n",
    "VIEW_SCORES = 5\n",
    "\n",
    "img_viewer = ImagesViewer()\n",
    "with h5py.File(result_file, \"r\") as f:\n",
    "    scores = f['scores'][()]\n",
    "    real_images = f['real_images']\n",
    "    synthetic_images = f['synth_images']\n",
    "\n",
    "    plt.hist(scores, bins=np.arange(0.5, 1, 0.01), \n",
    "             weights=np.ones_like(scores)/len(scores))\n",
    "    plt.xlabel(\"image similarity\")\n",
    "    plt.title(f\"Distribution of image scores for known frames\\n (video: {video_name})\")\n",
    "    plt.show()\n",
    "    \n",
    "    sorted_scores = np.argsort(scores)\n",
    "    step = int(len(sorted_scores)/VIEW_SCORES)\n",
    "    sorted_selection_index = [sorted_scores[0]] + sorted_scores[step:-step:step].tolist() + [sorted_scores[-1]]\n",
    "\n",
    "    for index in sorted_selection_index:\n",
    "        im = np.hstack([real_images[index], synthetic_images[index]])\n",
    "        img_viewer.add_image(im)\n",
    "    \n",
    "img_viewer.view_as_list(legends=scores[sorted_selection_index])"
   ]
  },
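  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible way to turn the score distribution into a threshold is to take a low percentile of the calibration scores, so that most correctly labelled frames would pass. The cell below is only a sketch of that idea: the helper `suggest_threshold`, the 5th-percentile choice and the example values are assumptions for illustration, not part of wormpose. In this notebook you could pass it the `scores` array loaded in the calibration cell instead of the made-up values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def suggest_threshold(scores, percentile=5):\n",
    "    \"\"\"Return the score below which `percentile` % of the calibration frames fall.\"\"\"\n",
    "    scores = np.asarray(scores, dtype=float)\n",
    "    # Ignore frames that could not be scored (NaN), if any\n",
    "    return float(np.nanpercentile(scores, percentile))\n",
    "\n",
    "\n",
    "# Made-up example values; replace with the `scores` array from the calibration cell\n",
    "example_scores = np.array([0.78, 0.85, 0.88, 0.90, 0.91, 0.93, 0.95])\n",
    "print(f\"Suggested threshold: {suggest_threshold(example_scores):.2f}\")"
   ]
  },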
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the training and evaluation dataset\n",
    "We now build the training dataset, which contains images such as the ones above, saved in the binary \".tfrecord\" format. We also build an evaluation dataset.\n",
    "\n",
    "We create a small training set of 1000 images for this tutorial. For a more representative training set, increase the value of num_train_samples (the default value is 500k), at the cost of a longer generation time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from wormpose.commands import generate\n",
    "from ipywidgets import FloatProgress\n",
    "from IPython.display import display\n",
    "\n",
    "fp = FloatProgress(min=0., max=1.)\n",
    "display(fp)\n",
    "\n",
    "gen_progress = generate(dataset_loader,\n",
    "                        dataset_path,\n",
    "                        worm_is_lighter=worm_is_lighter,\n",
    "                        num_train_samples=1000)\n",
    "for progress_value in gen_progress:\n",
    "    fp.value = progress_value"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check generated tfrecord files\n",
    "\n",
    "We check that the tfrecord files have been generated successfully by viewing the first few images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from glob import glob\n",
    "import os\n",
    "from wormpose.config import default_paths\n",
    "from wormpose.machine_learning import tfrecord_file\n",
    "from ipython_utils import ImagesViewer\n",
    "\n",
    "\n",
    "def view_tfrecord(filename, theta_dims=100, max_viz=100):\n",
    "    img_viewer = ImagesViewer()\n",
    "    # Count the frames actually read, so the message is right even for short or empty files\n",
    "    num_read = 0\n",
    "    for record in tfrecord_file.read(filename, theta_dims):\n",
    "        if num_read >= max_viz:\n",
    "            break\n",
    "        image_data = record[0].numpy()\n",
    "        img_viewer.add_image(image_data)\n",
    "        num_read += 1\n",
    "    print(f\"Read: \\'{filename}\\' (first {num_read} frames)\")\n",
    "\n",
    "    img_viewer.view_as_slider() \n",
    "\n",
    "train_records = list(sorted(glob(os.path.join(project_dir, \n",
    "                                              default_paths.TRAINING_DATA_DIR,\n",
    "                                              default_paths.SYNTH_TRAIN_DATASET_NAMES.format(index='*')))))\n",
    "print(f\"Training tfrecord files: {len(train_records)} files.\")\n",
    "if len(train_records) > 0 :\n",
    "    view_tfrecord(train_records[0])\n",
    "eval_record = list(glob(os.path.join(project_dir, \n",
    "                                     default_paths.TRAINING_DATA_DIR, \n",
    "                                     default_paths.REAL_EVAL_DATASET_NAMES.format(index='*'))))\n",
    "if len(eval_record) > 0 :\n",
    "    view_tfrecord(eval_record[0])    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train\n",
    "\n",
    "We train the network on the generated data.\n",
    "\n",
    "We train for only 10 epochs in this tutorial. Increase the number of epochs for better results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from wormpose.commands import train\n",
    "\n",
    "if tf.test.gpu_device_name() == '':\n",
    "    print(\"Warning, no GPU available for training, this will be very slow.\")\n",
    "    \n",
    "train(dataset_path, epochs=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict\n",
    "\n",
    "We can now predict the full video. We use a score threshold of 0.7 to discard wrong results (see the Calibration section above for how to choose the value of the threshold)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from wormpose.commands import predict\n",
    "\n",
    "use_pretrained_network = True\n",
    "if use_pretrained_network: # already trained model for \"sample_data\" only\n",
    "    model_path = os.path.join('wormpose_data', 'models', 'sample_data', 'trained_model.hdf5')\n",
    "else: # will use the default path for the model that was trained in the previous cell\n",
    "    model_path = None\n",
    "\n",
    "predict(dataset_path=dataset_path, \n",
    "        score_threshold=0.7,\n",
    "        model_path=model_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We visualize the result by drawing the posture skeleton on top of the original image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from wormpose.commands import visualize\n",
    "\n",
    "visualize(dataset_path, draw_original=False)\n",
    "!find -name 'images_results.zip' -exec sh -c 'unzip -o -d \"${1%.*}\" \"$1\"' _ {} \\;  > /dev/null"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, glob\n",
    "from wormpose.config import default_paths\n",
    "from ipython_utils import ImagesViewer\n",
    "\n",
    "MAX_FRAMES = 1000\n",
    "\n",
    "img_filenames = sorted(glob.glob(os.path.join(project_dir, default_paths.RESULTS_DIR, '*','*','*.png')))[:MAX_FRAMES]\n",
    "img_viewer = ImagesViewer()\n",
    "for filename in img_filenames:\n",
    "    img_viewer.add_image_filename(filename)\n",
    "\n",
    "img_viewer.view_as_list(legends=range(len(img_filenames)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
